package onion.mqtt.server.processor;

import io.netty.channel.Channel;
import io.netty.handler.codec.mqtt.MqttMessageBuilders;
import io.netty.handler.codec.mqtt.MqttQoS;
import io.netty.handler.codec.mqtt.MqttSubscribeMessage;
import io.netty.handler.codec.mqtt.MqttTopicSubscription;
import onion.mqtt.server.MqttServerBuilder;
import onion.mqtt.server.auth.IMqttServerSubscribeHandler;
import onion.mqtt.server.dispatcher.IMqttMessageDispatcher;
import onion.mqtt.server.dispatcher.MqttMessageDispatcher;
import onion.mqtt.server.event.IMqttServerSubscribeListener;
import onion.mqtt.server.manager.MessageManager;
import onion.mqtt.server.manager.SubscribeManager;
import onion.mqtt.server.store.MessageStore;
import onion.mqtt.server.store.SubscribeStore;
import onion.mqtt.server.utils.TopicUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;

/**
 * Processes MQTT SUBSCRIBE packets.
 *
 * <p>For each requested topic filter the processor validates the filter, asks the optional
 * {@link IMqttServerSubscribeHandler} whether the subscription is allowed, records accepted
 * subscriptions in {@link SubscribeManager}, dispatches any retained messages for the filter,
 * and notifies the optional {@link IMqttServerSubscribeListener} asynchronously. Finally it
 * replies with a single SUBACK carrying one granted QoS (or {@link MqttQoS#FAILURE}) per
 * requested filter, in request order.
 *
 * @author Mr, Lu
 * @developmentTeam Zhejiang Yunze Information Technology Co., Ltd. (浙江允泽信息科技有限公司)
 * @createTime 2023/12/12
 */
public class SubscribeProcessor extends AbstractMqttServerProcessor<MqttSubscribeMessage> {
    static final Logger log = LoggerFactory.getLogger(SubscribeProcessor.class);
    /** Optional subscription ACL; {@code null} means every subscription is accepted. */
    private final IMqttServerSubscribeHandler subscribeAcl;
    /** Optional listener notified after a subscription is accepted; may be {@code null}. */
    private final IMqttServerSubscribeListener subscribeEvent;
    private final IMqttMessageDispatcher messageDispatcher;

    /**
     * Creates a processor that takes its ACL handler and subscribe listener from the builder.
     *
     * @param serverBuilder server configuration supplying the subscribe handler and listener
     */
    public SubscribeProcessor(MqttServerBuilder serverBuilder) {
        this.subscribeAcl = serverBuilder.getSubscribeHandler();
        this.subscribeEvent = serverBuilder.getSubscribeListener();
        this.messageDispatcher = new MqttMessageDispatcher();
    }

    /**
     * Handles one SUBSCRIBE packet and answers it with a SUBACK.
     *
     * @param channel the client connection
     * @param message the SUBSCRIBE packet received from the client
     */
    @Override
    public void process(Channel channel, MqttSubscribeMessage message) {
        String clientId = getClientId(channel);
        log.debug("Subscribe clientId: {}", clientId);
        int messageId = message.variableHeader().messageId();

        List<MqttTopicSubscription> topicSubscriptions = message.payload().topicSubscriptions();
        // One granted QoS per requested filter, kept in request order for the SUBACK.
        List<MqttQoS> grantedQoses = new ArrayList<>(topicSubscriptions.size());
        for (MqttTopicSubscription topicSubscription : topicSubscriptions) {
            grantedQoses.add(subscribe(channel, clientId, messageId, topicSubscription));
        }

        // Reply with the SUBACK.
        writeAndFlush(channel, MqttMessageBuilders.subAck()
                .addGrantedQoses(grantedQoses.toArray(new MqttQoS[0]))
                .packetId(messageId)
                .build());
    }

    /**
     * Validates, authorizes and (if allowed) registers a single topic subscription.
     *
     * @return the QoS to report in the SUBACK: the requested QoS when accepted,
     *         {@link MqttQoS#FAILURE} when rejected by the ACL handler
     */
    private MqttQoS subscribe(Channel channel, String clientId, int messageId,
                              MqttTopicSubscription topicSubscription) {
        String topic = topicSubscription.topicName();
        // Check that the topic filter is well-formed.
        // NOTE(review): if validateTopicFilter throws, the exception escapes process(): the
        // remaining filters are skipped, no SUBACK is sent, and subscriptions already stored
        // for earlier filters in this packet stay registered — confirm that is intended.
        TopicUtils.validateTopicFilter(topic);
        MqttQoS mqttQoS = topicSubscription.qualityOfService();

        // Authorize the subscription (only when an ACL handler is configured).
        if (subscribeAcl != null && !subscribeAcl.verifyTopic(channel, clientId, topic, mqttQoS)) {
            log.debug("Subscribe - clientId:{} topic:{} mqttQoS:{} valid failed messageId:{}", clientId, topic, mqttQoS, messageId);
            return MqttQoS.FAILURE;
        }

        // Store the subscription.
        SubscribeManager.getInstance().addSubscribe(new SubscribeStore(clientId, topic, mqttQoS));
        log.debug("Subscribe - clientId:{} topic:{} mqttQoS:{} messageId:{}", clientId, topic, mqttQoS, messageId);

        // Deliver retained messages matching this filter.
        List<MessageStore> retainMessage = MessageManager.getInstance().getRetainMessage(topic);
        messageDispatcher.dispatchRetainMsg(channel, retainMessage);

        notifySubscribed(channel, clientId, topic, mqttQoS);
        return mqttQoS;
    }

    /**
     * Notifies the subscribe listener, if one is configured, on the common fork-join pool so a
     * slow or failing listener cannot stall or break SUBSCRIBE handling.
     */
    private void notifySubscribed(Channel channel, String clientId, String topic, MqttQoS mqttQoS) {
        if (subscribeEvent == null) {
            // Nothing to notify; avoid scheduling an empty async task.
            return;
        }
        CompletableFuture.runAsync(() -> {
            try {
                subscribeEvent.onSubscribe(channel, clientId, topic, mqttQoS);
            } catch (Throwable e) {
                // Pass the throwable so the listener's failure (and its cause) is not lost.
                log.error("subscribe publishEvent error clientId: {}.", clientId, e);
            }
        });
    }
}
